import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionConv1d(nn.Module):
    '''
    1-D convolution whose kernel taps are re-weighted by embedding similarity.

    Convolution is equivalent to Unfold + matrix multiplication + view to the
    output shape. Between the unfold and the matrix multiplication, every
    unfolded input value at temporal offset ``d`` (relative to the centre tap)
    is multiplied by the cosine similarity between ``embedding[:, :, t]`` and
    ``embedding[:, :, t + d]`` (0 where ``t + d`` is out of range).

    The convolution weight is supplied to ``forward`` (e.g. an ``nn.Conv1d``
    weight), so this module owns no parameters.

    Args:
        kernel_size: positive odd number of temporal taps. Padding is
            ``kernel_size // 2`` so the output keeps the input's length.
        out_channels: number of output channels; expected to equal
            ``weight.size(0)`` in ``forward``.
    '''

    def __init__(self, kernel_size, out_channels):
        super(AttentionConv1d, self).__init__()
        # Only odd kernels give a centred window with "same" padding; the
        # previous implementation worked only for kernel_size == 3.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be a positive odd integer, got {kernel_size}")
        self.kernel_size = kernel_size
        self.out_channels = out_channels
        self.cosine_similarity = nn.CosineSimilarity(dim=1)

    def calculate_similarity(self, embedding, embedding_neighbor):
        '''Return cosine similarity over dim 1, re-inserted as a size-1 dim 1.'''
        similarity = self.cosine_similarity(embedding, embedding_neighbor)
        similarity = torch.unsqueeze(similarity, dim=1)
        return similarity

    @staticmethod
    def _shift(embedding, offset):
        '''Return a copy where time index t holds ``embedding[:, :, t + offset]``, zero-filled out of range.'''
        if offset == 0:
            return embedding
        shifted = torch.zeros_like(embedding)
        length = embedding.size(2)
        if abs(offset) >= length:
            # Every neighbour is out of range; guard against negative-slice wraparound.
            return shifted
        if offset > 0:
            shifted[:, :, :length - offset] = embedding[:, :, offset:]
        else:
            shifted[:, :, -offset:] = embedding[:, :, :length + offset]
        return shifted

    def cal_local_attenttion(self, embedding, feature, kernel_size):
        '''
        Build per-tap attention weights laid out like the rows of ``F.unfold``.

        Args:
            embedding: tensor indexed as (batch, embed_dim, time); similarity is
                taken over dim 1.
            feature: (batch, channel, time) tensor; only its channel count and
                dtype/device are used.
            kernel_size: odd number of temporal taps.

        Returns:
            Tensor of shape (batch, kernel_size * channel, time) whose row
            ``c * kernel_size + j`` holds the similarity between time t and
            time ``t + j - kernel_size // 2``.
        '''
        half = kernel_size // 2
        similarity = torch.cat(
            [self.calculate_similarity(embedding, self._shift(embedding, j - half))
             for j in range(kernel_size)],
            dim=1,
        )  # [B, k, T]

        channel = feature.size(1)
        # F.unfold orders its rows channel-major (row = c * k + j), the same
        # order as weight.view(out, -1). The k tap similarities must therefore
        # be tiled `channel` times back-to-back; repeating each one `channel`
        # times in a row would pair tap j with the wrong input values.
        return similarity.repeat(1, channel, 1).type_as(feature)

    def forward(self, feature, embedding, weight):
        '''
        Apply the similarity-weighted convolution.

        Args:
            feature: (batch, in_channels, time) input.
            embedding: tensor indexed as (batch, embed_dim, time) with the same
                time length as ``feature``.
            weight: convolution weight of shape
                (out_channels, in_channels, kernel_size), e.g. ``nn.Conv1d.weight``.

        Returns:
            (batch, out_channels, time) output.
        '''
        batch, channel, temporal_length = feature.size()
        # Treat the sequence as a (time x 1) image so unfold extracts temporal windows.
        inp_unf = F.unfold(
            torch.unsqueeze(feature, dim=3),
            kernel_size=(self.kernel_size, 1),
            padding=(self.kernel_size // 2, 0),
        )  # [B, C * k, T]
        attention = self.cal_local_attenttion(embedding, feature, kernel_size=self.kernel_size)
        inp_weight = inp_unf * attention
        w_t = weight.reshape(weight.size(0), -1).t()  # [C * k, out]
        results = torch.matmul(inp_weight.transpose(1, 2), w_t)  # [B, T, out]
        return results.transpose(1, 2).reshape(batch, self.out_channels, temporal_length)


class FilterModule(nn.Module):
    '''
    Per-time-step gate in (0, 1) computed from channel-first features.

    Two 1x1 convolutions (len_feature -> 512 -> 1) with LeakyReLU in between
    and a Sigmoid at the end. Input (batch, len_feature, time) maps to
    output (batch, 1, time).
    '''

    def __init__(self, len_feature):
        super().__init__()
        hidden_width = 512
        self.conv_1 = nn.Sequential(
            nn.Conv1d(len_feature, hidden_width, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(),
        )
        self.conv_2 = nn.Sequential(
            nn.Conv1d(hidden_width, 1, kernel_size=1, stride=1, padding=0),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.conv_2(self.conv_1(x))


class EmbeddingModule(nn.Module):
    '''
    Map channel-first features to unit-length 32-d per-time-step embeddings.

    Two 3-tap "same"-padded Conv1d layers (len_feature -> 512 -> 32) with a
    LeakyReLU after the first, then L2 normalisation over the channel dim.
    Input (batch, len_feature, time) maps to output (batch, 32, time).
    '''

    def __init__(self, len_feature):
        super().__init__()
        self.conv_1 = nn.Conv1d(len_feature, 512, kernel_size=3, stride=1, padding=1)
        self.conv_2 = nn.Conv1d(512, 32, kernel_size=3, stride=1, padding=1)
        self.lrelu = nn.LeakyReLU()

    def forward(self, x):
        hidden = self.lrelu(self.conv_1(x))
        projected = self.conv_2(hidden)
        return F.normalize(projected, p=2, dim=1)


class BaseModule(nn.Module):
    '''
    Two stacked similarity-weighted temporal convolutions (-> 2048 channels).

    ``conv_1`` and ``conv_2`` are ordinary bias-free ``nn.Conv1d`` layers used
    only as weight holders: their own forward is never called. The weights are
    fed to the matching ``AttentionConv1d`` modules together with the
    embedding.

    forward(x, embedding) returns ``(feat1, feature)``:
      - ``feat1`` is LeakyReLU of the first layer.
      - ``feature`` is Dropout(0.7) of LeakyReLU of the second layer.
      - Both are (batch, 2048, time).
    '''

    def __init__(self, len_feature):
        super().__init__()
        kernel = 3
        width = 2048
        self.conv_1 = nn.Conv1d(len_feature, width, kernel_size=kernel, stride=1, padding=kernel // 2, bias=False)
        self.conv_1_att = AttentionConv1d(kernel_size=kernel, out_channels=width)
        self.conv_2 = nn.Conv1d(width, width, kernel_size=kernel, stride=1, padding=kernel // 2, bias=False)
        self.conv_2_att = AttentionConv1d(kernel_size=kernel, out_channels=width)
        self.lrelu = nn.LeakyReLU()
        self.drop_out = nn.Dropout(0.7)

    def forward(self, x, embedding):
        first = self.lrelu(self.conv_1_att(x, embedding, self.conv_1.weight))
        second = self.lrelu(self.conv_2_att(first, embedding, self.conv_2.weight))
        return first, self.drop_out(second)


class Cls_Module(nn.Module):
    '''
    Time-major temporal classifier head.

    Takes (batch, time, len_feature) input and permutes it to channel-first
    order. A 3-tap Conv1d + ReLU then lifts it to 2048 channels, dropout is
    applied, and a bias-free 1x1 Conv1d produces num_classes + 1 scores per
    time step.

    Returns ``(feat, cas)``:
      - ``feat`` is (batch, time, 2048), taken before dropout.
      - ``cas`` is (batch, time, num_classes + 1).
    '''

    def __init__(self, len_feature, num_classes):
        super().__init__()
        self.len_feature = len_feature
        hidden = 2048
        self.conv_1 = nn.Sequential(
            nn.Conv1d(len_feature, hidden, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )
        self.classifier = nn.Sequential(
            nn.Conv1d(hidden, num_classes + 1, kernel_size=1, stride=1, padding=0, bias=False),
        )
        self.drop_out = nn.Dropout(p=0.7)

    def forward(self, x):
        channel_first = x.permute(0, 2, 1).float()  # (B, F, T)
        hidden = self.conv_1(channel_first)          # (B, 2048, T)
        scores = self.classifier(self.drop_out(hidden))
        return hidden.permute(0, 2, 1), scores.permute(0, 2, 1)


class ClassifierModule(nn.Module):
    '''
    Channel-first temporal classifier head.

    A 3-tap Conv1d + ReLU lifts (batch, len_feature, time) input to 2048
    channels, dropout is applied, and a bias-free 1x1 Conv1d produces
    num_classes + 1 scores per time step. Unlike Cls_Module, neither the
    input nor the class scores are permuted.

    Returns ``(feat, cas)``:
      - ``feat`` is (batch, time, 2048), taken before dropout.
      - ``cas`` is (batch, num_classes + 1, time).
    '''

    def __init__(self, len_feature, num_classes):
        super().__init__()
        self.len_feature = len_feature
        hidden = 2048
        self.conv_1 = nn.Sequential(
            nn.Conv1d(len_feature, hidden, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )
        self.classifier = nn.Sequential(
            nn.Conv1d(hidden, num_classes + 1, kernel_size=1, stride=1, padding=0, bias=False),
        )
        self.drop_out = nn.Dropout(p=0.7)

    def forward(self, x):
        hidden = self.conv_1(x.float())                  # (B, 2048, T)
        scores = self.classifier(self.drop_out(hidden))  # (B, C + 1, T)
        return hidden.permute(0, 2, 1), scores


class Model(nn.Module):
    '''
    Two-branch temporal classification model.

    The input features are processed twice by a shared BaseModule +
    ClassifierModule:
      - once as-is (the "base" branch);
      - once after being multiplied by the per-time-step gate produced by
        FilterModule (the "supp" branch, gate returned as ``fore_weights``).
    Both branches use the embedding computed from the un-gated input.

    Args:
        len_feature: number of input feature channels.
        num_classes: number of classes; the classifier emits num_classes + 1
            channels, the last one treated separately below.
        r_act: divisor giving the top-k size ``k_act = T // r_act`` (at least 1).
    '''

    def __init__(self, len_feature, num_classes, r_act):
        super(Model, self).__init__()
        self.len_feature = len_feature
        self.num_classes = num_classes
        self.r_act = r_act

        # cls_module is not used in forward; it is kept so parameters and
        # state_dict keys stay unchanged.
        self.cls_module = Cls_Module(len_feature, num_classes)
        self.sigmoid = nn.Sigmoid()
        self.filter_module = FilterModule(len_feature)
        self.embedding_module = EmbeddingModule(len_feature)
        self.base_module = BaseModule(len_feature)
        # The classifier consumes base_module's output, not the raw input, so
        # its input width must match base_module's output channels. Using
        # len_feature here only worked when len_feature happened to be 2048.
        self.classifier = ClassifierModule(self.base_module.conv_2.out_channels, num_classes)

    def forward(self, x, vid_labels=None):
        '''
        Args:
            x: input features; dim 1 is the time/segment axis and the last dim
                holds len_feature channels — assumed (batch, time, len_feature).
            vid_labels: optional video-level label tensor broadcast against the
                (batch, num_classes) scores — presumably multi-hot; confirm with caller.

        Returns:
            vid_score: (batch, num_classes). Mean of the top-k_act sigmoid CAS
                values over time. When ``vid_labels`` is given, classes whose
                label is 0 use the mean over all time steps instead.
            cas_sigmoid_fuse: (batch, time, num_classes + 1). Sigmoid CAS whose
                first num_classes channels are scaled by (1 - last channel).
            feat_base: (batch, time, 2048) classifier features of the base branch.
            list of:
              - score_base: (batch, num_classes + 1);
              - cas_base: (batch, num_classes + 1, time);
              - score_supp: (batch, num_classes + 1);
              - cas_supp: (batch, num_classes + 1, time);
              - embedding: (batch, 32, time);
              - fore_weights: (batch, 1, time).
        '''
        num_segments = x.shape[1]
        # With fewer segments than r_act, integer division gives 0. Top-k of 0
        # elements then averages to NaN, so always keep at least one segment.
        k_act = max(1, num_segments // self.r_act)

        x = x.float().permute(0, 2, 1)  # (B, F, T)
        fore_weights = self.filter_module(x)  # (B, 1, T), values in (0, 1)
        x_supp = fore_weights * x
        embedding = self.embedding_module(x)

        # Call order is kept as-is so dropout RNG consumption is unchanged.
        _, feature_base = self.base_module(x, embedding)
        feat_base, cas_base = self.classifier(feature_base)  # cas: (B, C + 1, T)
        _, feature_supp = self.base_module(x_supp, embedding)
        _, cas_supp = self.classifier(feature_supp)

        score_base = torch.mean(torch.topk(cas_base, k_act, dim=2)[0], dim=2)
        score_supp = torch.mean(torch.topk(cas_supp, k_act, dim=2)[0], dim=2)

        cas_sigmoid = self.sigmoid(cas_base.permute(0, 2, 1))  # (B, T, C + 1)

        # The extra (+1) last channel gates the class channels
        # (presumably a background score).
        last = cas_sigmoid[:, :, -1:]
        cas_sigmoid_fuse = torch.cat((cas_sigmoid[:, :, :-1] * (1 - last), last), dim=2)

        value, _ = cas_sigmoid.sort(descending=True, dim=1)
        topk_scores = value[:, :k_act, :-1]

        if vid_labels is None:
            vid_score = torch.mean(topk_scores, dim=1)
        else:
            vid_score = (torch.mean(topk_scores, dim=1) * vid_labels) + (torch.mean(cas_sigmoid[:, :, :-1], dim=1) * (1 - vid_labels))

        return vid_score, cas_sigmoid_fuse, feat_base, [score_base, cas_base, score_supp, cas_supp, embedding, fore_weights]
